"""
网络模型的保存与读取
"""
import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# 方式1 保存模型结构+参数
# torch.save(vgg16, "vgg16_method1.pth")
# 加载
# model = torch.load("vgg16_method1.pth")
# print(model)

# 方式2 保存模型参数
# torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 加载
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("vgg16_method2.pth"))
print(model)

# notes：自定义模型导入时必须要能访问到定义网络结构的文件

